import torch
from torch import nn
from transformers import Trainer
import torch.nn.functional as F
import copy, os
import deepspeed
from evaluate_util import get_dataloader, get_all_evals, get_all_evals_rouge
import copy
import json 
from pathlib import Path
from data_module import get_batch_loss 
from utils import merge_dicts, interleave_eval_result_dict, get_forget_quality, get_model_utility
import numpy as np
from scipy.stats import ks_2samp, hmean
import csv 
from transformers.integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available

def printll(name, inp):
    """Print ``name`` followed by the items of ``inp`` rounded to 4 decimals."""
    rounded = list(map(lambda value: round(value, 4), inp))
    print(name, rounded)

class CustomLoss(nn.Module):
    """Label-smoothed token-level negative log-likelihood with optional per-sample weights.

    The result is ``(1 - epsilon) * nll + epsilon * smoothed`` where ``nll`` is the
    negative log-probability of the target token and ``smoothed`` is the negative
    log-probability summed over the whole vocabulary.  Both terms are normalised
    by the number of positions whose label differs from ``ignore_index`` (not by
    the sum of the weights).

    Args:
        ignore_index: Label value marking positions excluded from the loss.
        epsilon: Weight of the smoothing term.
    """

    def __init__(self, ignore_index=-100, epsilon=0.1):
        super().__init__()
        self.ignore_index = ignore_index
        self.epsilon = epsilon

    # Implemented as ``forward`` (rather than overriding ``__call__``) so the usual
    # ``nn.Module`` call machinery is kept; ``CustomLoss()(...)`` works as before.
    def forward(self, model_output, labels, shift_labels=False, coefs=None):
        """Compute the weighted, label-smoothed loss.

        Args:
            model_output: Dict with a ``"logits"`` entry, or a sequence whose first
                element is the logits tensor.
            labels: Target token ids; positions equal to ``ignore_index`` are masked.
            shift_labels: When true, drop the last logit and the first label along
                the sequence axis so position ``t`` is scored against token ``t+1``.
            coefs: Optional per-sample weights of shape ``(batch,)`` or
                ``(batch, 1)``.  ``None`` means every sample has weight 1.

        Returns:
            A scalar tensor holding the loss.
        """
        logits = model_output["logits"] if isinstance(model_output, dict) else model_output[0]

        # Shift labels if needed
        if shift_labels:
            logits = logits[..., :-1, :].contiguous()  # Remove last logit
            labels = labels[..., 1:].contiguous()      # Remove first label

        log_probs = -nn.functional.log_softmax(logits, dim=-1)

        if labels.dim() == log_probs.dim() - 1:
            labels = labels.unsqueeze(-1)  # Add dimension to labels

        padding_mask = labels.eq(self.ignore_index)

        # Clamp labels to avoid index errors (ignored positions are zeroed below)
        labels = torch.clamp(labels, min=0)

        nll_loss = log_probs.gather(dim=-1, index=labels)

        # Handle fp16 inputs
        smoothed_loss = log_probs.sum(dim=-1, keepdim=True, dtype=torch.float32)

        nll_loss.masked_fill_(padding_mask, 0.0)
        smoothed_loss.masked_fill_(padding_mask, 0.0)

        num_active_elements = padding_mask.numel() - padding_mask.long().sum()

        if coefs is not None:
            # The loss tensors carry a trailing singleton dim (assumed layout:
            # (batch, seq, 1) -- TODO confirm for non 3-D logits).  A 0-D/1-D weight
            # vector is reshaped to (batch, 1, 1) so it scales whole samples; a
            # (batch, 1) tensor is expanded to (batch, 1, 1) exactly as before.
            if coefs.dim() <= 1:
                coefs = coefs.reshape(-1, 1, 1)
            else:
                coefs = coefs.unsqueeze(1)
            nll_loss = nll_loss * coefs
            smoothed_loss = smoothed_loss * coefs

        nll_loss = nll_loss.sum() / num_active_elements
        smoothed_loss = smoothed_loss.sum() / (num_active_elements * log_probs.shape[-1])

        return (1 - self.epsilon) * nll_loss + self.epsilon * smoothed_loss



class CustomTrainer(Trainer):
    """Trainer for batches delivered as an ``(input_ids, labels, attention_mask)`` triple."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Run a forward pass and return the loss reported by the model.

        Args:
            model: Callable accepting ``input_ids`` plus ``labels`` and
                ``attention_mask`` keywords.
            inputs: Triple ``(input_ids, labels, attention_mask)``.
            return_outputs: When true, return ``(loss, outputs)`` instead of the loss.
            num_items_in_batch: Accepted so that callers which pass this keyword
                (as the other trainer in this file expects) do not fail with a
                ``TypeError``; it is not used because the model's own loss is
                returned unchanged.
        """
        input_ids, labels, attention_mask = inputs
        # forward pass
        outputs = model(input_ids, labels=labels, attention_mask=attention_mask)
        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only: bool, ignore_keys=None):
        """Evaluate one batch without gradients and return ``(loss, logits, labels)``.

        ``prediction_loss_only`` and ``ignore_keys`` are accepted but not read here.
        """
        input_ids, labels, attention_mask = inputs
        # forward pass
        with torch.no_grad():
            outputs = model(input_ids, labels=labels, attention_mask=attention_mask)
            logits = outputs.logits
            loss = outputs.loss
        return (loss, logits, labels)
    


class CustomTrainerForgetting(Trainer):
    def __init__(self, *args, **kwargs):
        """Pull the trainer-specific options out of ``kwargs`` and initialise the base class.

        Required keyword arguments (a ``KeyError`` is raised if one is missing):
            forget_loss: Stored as ``self.loss_type``.
            oracle_model: Reference model; passed through ``e_prepare_deepspeed``
                when ``forget_loss`` is ``"KL"``.
            eval_cfg: Stored as ``self.eval_cfg`` and read by ``evaluate``.

        Everything else is forwarded untouched to the parent constructor.
        """
        # These keys must be removed before the parent constructor sees kwargs.
        self.loss_type = kwargs.pop('forget_loss')
        self.oracle_model = kwargs.pop('oracle_model')
        self.eval_cfg = kwargs.pop('eval_cfg')

        super().__init__(*args, **kwargs)

        needs_oracle_engine = self.loss_type == "KL"
        if needs_oracle_engine:
            self.oracle_model = self.e_prepare_deepspeed(self.oracle_model)

    def e_prepare_deepspeed(self, model):
        """Initialise ``model`` through DeepSpeed for frozen, evaluation-only use.

        A deep copy of the accelerator's DeepSpeed plugin config is adjusted and
        handed to ``deepspeed.initialize``.  The first element of its result is put
        in eval mode, every parameter gets ``requires_grad = False``, and that
        object is returned.
        """
        # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
        ds_config = copy.deepcopy(self.accelerator.state.deepspeed_plugin.deepspeed_config)

        if model is not None and hasattr(model, "config"):
            model_cfg = model.config
            hidden_sizes = getattr(model_cfg, "hidden_sizes", None)
            hidden_size = max(hidden_sizes) if hidden_sizes else getattr(model_cfg, "hidden_size", None)
            if hidden_size is not None and ds_config["zero_optimization"]["stage"] == 3:
                # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
                # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
                # NOTE(review): these are flat dotted keys at the top level of the
                # config, not entries nested under "zero_optimization" -- confirm
                # that is what the consumer of this config expects.
                ds_config["zero_optimization.reduce_bucket_size"] = hidden_size * hidden_size
                ds_config["zero_optimization.stage3_param_persistence_threshold"] = 10 * hidden_size
                ds_config["zero_optimization.stage3_prefetch_bucket_size"] = 0.9 * hidden_size * hidden_size

        # If ZeRO-3 is used, we shard both the active and reference model.
        # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
        zero_cfg = ds_config["zero_optimization"]
        if zero_cfg["stage"] != 3:
            zero_cfg["stage"] = 0
        ds_config["optimizer"] = {"type": None}

        engine, *_ = deepspeed.initialize(model=model, config=ds_config)
        engine.eval()
        # The reference model is never trained: switch gradients off everywhere.
        for param in engine.parameters():
            param.requires_grad = False

        return engine

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the weighted loss over a combined forget ("idk") + retain batch.

        Args:
            model: Callable accepting ``input_ids`` plus ``labels`` and
                ``attention_mask`` keywords.
            inputs: Pair ``(idk_inputs, retain_inputs)``; each element is a
                4-tuple ``(input_ids, labels, attention_mask, coefs)``.
            return_outputs: When true, return ``(loss, outputs)`` instead of the loss.
            num_items_in_batch: Forwarded to ``self.compute_loss_func`` when one
                is configured; otherwise unused.

        Returns:
            The loss tensor, or ``(loss, outputs)`` when ``return_outputs`` is true.
        """
        idk_inputs, retain_inputs = inputs
        idk_input_ids, idk_labels, idk_attention_mask, idk_coefs = idk_inputs
        retain_input_ids, retain_labels, retain_attention_mask, retain_coefs = retain_inputs

        # concatenate the inputs. single forward pass is much more efficient
        input_ids = torch.cat((idk_input_ids, retain_input_ids), dim=0)
        labels = torch.cat((idk_labels, retain_labels), dim=0)
        attention_mask = torch.cat((idk_attention_mask, retain_attention_mask), dim=0)
        coefs = torch.cat((idk_coefs, retain_coefs), dim=0)

        # Labels are still forwarded to the model as before; the loss the model
        # derives from them is not used below.
        outputs = model(input_ids, labels=labels, attention_mask=attention_mask)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        # ``labels`` is always a tensor at this point, so the former
        # "model returned its own loss" fallback and the label-smoother branch
        # could never run and were removed, together with an unused model-name
        # lookup (``base_model.model``) that could raise ``AttributeError``.
        compute_loss_func = getattr(self, "compute_loss_func", None)
        if compute_loss_func is not None:
            # User-defined compute_loss function
            loss = compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch)
        else:
            # shift_labels=True: position t is scored against token t+1.
            loss = CustomLoss()(outputs, labels, True, coefs)
        return (loss, outputs) if return_outputs else loss

    

    # def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
    #     if self.loss_type == "grad_ascent":
    #         forget_inputs, retain_inputs = inputs
    #         input_ids, labels, attention_mask = forget_inputs
    #         outputs = model(input_ids,labels=labels, attention_mask=attention_mask)
    #         forget_loss = outputs.loss
    #         forget_loss = forget_loss * -1
    #         loss = forget_loss
        
    #     elif self.loss_type == "switch":
    #         # import pdb; pdb.set_trace();
    #         idk_inputs, retain_inputs = inputs
    #         idk_input_ids, idk_labels, idk_attention_mask, idk_coefs = idk_inputs
    #         retain_input_ids, retain_labels, retain_attention_mask, retain_coefs = retain_inputs
            
    #         #concatenate the inputs. single forward pass is much more efficient
    #         input_ids = torch.cat((idk_input_ids, retain_input_ids), dim=0)
    #         labels = torch.cat((idk_labels, retain_labels), dim=0)
    #         attention_mask = torch.cat((idk_attention_mask, retain_attention_mask), dim=0)
    #         # import pdb; pdb.set_trace();
    #         coefs = torch.cat((idk_coefs, retain_coefs), dim=0)
            
    #         outputs = model(input_ids,labels=labels, attention_mask=attention_mask, reduction="none")
    #         logits = outputs.logits.view(-1, outputs.logits.size(-1))  # Shape: (batch_size * seq_length, vocab_size)
    #         labels_flat = labels.view(-1) 
    #         # Calculate per-token loss with CrossEntropyLoss
    #         loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction="none")  # Ensure padding tokens are ignored
    #         per_token_loss = loss_fn(logits, labels_flat)  # Shape: (batch_size * seq_length)

    #         # Reshape per-token loss back to (batch_size, seq_length)
    #         per_token_loss = per_token_loss.view(labels.size(0), -1)  # Shape: (batch_size, seq_length)

    #         # Mask out padding tokens and compute mean loss per sample
    #         masked_loss = per_token_loss * attention_mask  # Apply attention mask
    #         loss_per_sample = masked_loss.sum(dim=1) / attention_mask.sum(dim=1)  # Mean across valid tokens per sample

    #         # Apply coefficients to each sample's average loss
    #         # Ensure coefs is the same size as the number of samples in the batch
    #         weighted_loss_per_sample = loss_per_sample * coefs
    #         import pdb; pdb.set_trace();

    #         print(coefs)


    #         # import pdb; pdb.set_trace();
    #         loss = weighted_loss_per_sample.mean()

    #         # loss = outputs.loss


    #     elif self.loss_type == "grad_diff":
    #         forget_inputs, retain_inputs = inputs
    #         input_ids, labels, attention_mask = forget_inputs
    #         outputs = model(input_ids,labels=labels, attention_mask=attention_mask)
    #         forget_loss = outputs.loss
    #         forget_loss = forget_loss * -1

    #         retain_input_ids, retain_labels, retain_attention_mask = retain_inputs
    #         retain_outputs = model(retain_input_ids,labels=retain_labels, attention_mask=retain_attention_mask)
    #         retain_loss = retain_outputs.loss
    #         loss = forget_loss + retain_loss
        
    #     elif self.loss_type == "KL":
    #         forget_inputs, retain_inputs = inputs
    #         input_ids, labels, attention_mask = forget_inputs
    #         outputs = model(input_ids,labels=labels, attention_mask=attention_mask)
    #         forget_loss = outputs.loss
    #         forget_loss = forget_loss * -1
    #         retain_input_ids, retain_labels, retain_attention_mask = retain_inputs
    #         with torch.no_grad():
    #             retain_outputs = self.oracle_model(retain_input_ids,labels=retain_labels, attention_mask=retain_attention_mask)
            
    #         retain_probs = F.log_softmax(retain_outputs.logits, dim=-1)
    #         retain_probs = retain_probs.view(-1, retain_outputs.logits.shape[-1])

    #         current_outputs = model(retain_input_ids,labels=retain_labels, attention_mask=retain_attention_mask)
    #         current_probs = F.log_softmax(current_outputs.logits, dim=-1)
    #         current_probs = current_probs.view(-1, current_outputs.logits.shape[-1])

    #         #minimum KL divergence
    #         retain_loss = nn.functional.kl_div(current_probs, retain_probs, reduction='batchmean', log_target=True)
    #         loss = forget_loss + retain_loss

    #     elif self.loss_type == "idk":
    #         idk_inputs, retain_inputs = inputs
    #         idk_input_ids, idk_labels, idk_attention_mask = idk_inputs
    #         retain_input_ids, retain_labels, retain_attention_mask = retain_inputs
            
    #         #concatenate the inputs. single forward pass is much more efficient
    #         input_ids = torch.cat((idk_input_ids, retain_input_ids), dim=0)
    #         labels = torch.cat((idk_labels, retain_labels), dim=0)
    #         attention_mask = torch.cat((idk_attention_mask, retain_attention_mask), dim=0)
            
    #         outputs = model(input_ids,labels=labels, attention_mask=attention_mask)
    #         loss = outputs.loss
        
    #     elif self.loss_type == "dpo":
    #         idk_inputs, forget_inputs, retain_inputs = inputs
    #         idk_input_ids, idk_labels, idk_attention_mask = idk_inputs
    #         forget_input_ids, forget_labels, forget_attention_mask = forget_inputs
    #         idk_outputs = model(idk_input_ids,labels=idk_labels, attention_mask=idk_attention_mask)
    #         forget_outputs = model(forget_input_ids,labels=forget_labels, attention_mask=forget_attention_mask)

    #         with torch.no_grad():
    #             idk_outputs_oracle = self.oracle_model(idk_input_ids,labels=idk_labels, attention_mask=idk_attention_mask)
    #             forget_outputs_oracle = self.oracle_model(forget_input_ids,labels=forget_labels, attention_mask=forget_attention_mask)
    #             idk_logits_oracle = idk_outputs_oracle.logits
    #             forget_logits_oracle = forget_outputs_oracle.logits

    #         idk_loss_oracle = -1 * get_batch_loss(idk_logits_oracle, idk_labels)
    #         forget_loss_oracle = -1 * get_batch_loss(forget_logits_oracle, labels)
            
    #         idk_loss_current = -1 * get_batch_loss(idk_outputs.logits, idk_labels)
    #         forget_loss_current = -1 * get_batch_loss(forget_outputs.logits, labels)


    #         pi_logratios = idk_loss_current - forget_loss_current
    #         ref_logratios = idk_loss_oracle - forget_loss_oracle

    #         beta = 0.1
    #         loss = -F.logsigmoid(beta * (pi_logratios - ref_logratios)).mean()
    #         print(loss.item())
    #         loss = -pi_logratios.mean()
    #         loss = -idk_loss_current.mean()

    #         outputs = forget_outputs
        
    #     return (loss, outputs) if return_outputs else loss
        
    
    def prediction_step(self, model, inputs, prediction_loss_only: bool, ignore_keys=None):
        """Evaluate one ``(input_ids, labels, attention_mask)`` batch without gradients.

        Returns ``(loss, logits, labels)`` where ``loss`` and ``logits`` are read
        from the model output and ``labels`` is the tensor taken from ``inputs``.
        ``prediction_loss_only`` and ``ignore_keys`` are accepted but not read here.
        """
        token_ids, targets, mask = inputs
        with torch.no_grad():
            result = model(token_ids, labels=targets, attention_mask=mask)
            batch_logits = result.logits
            batch_loss = result.loss
        return (batch_loss, batch_logits, targets)

    def evaluate(
        self,
        eval_dataset = None,
        ignore_keys = None,
        metric_key_prefix = "eval",
    ):
        """Run every evaluation task listed in ``self.eval_cfg`` and write the results to disk.

        For each task a JSON log is written to
        ``<eval_cfg.save_dir>/checkpoint-<global_step>/<eval_task>.json`` (with a
        ``_<local_process_index>`` suffix when several processes are running).
        With more than one process, the local main process then merges the
        per-process files into ``<eval_task>.json`` and deletes the partial files.
        The local main process finally writes ``eval_log_aggregated.json`` and,
        when ``eval_cfg.retain_result`` is set, ``aggregate_stat.csv``.

        ``eval_dataset``, ``ignore_keys`` and ``metric_key_prefix`` are accepted but
        not read here.  Nothing is returned.
        """
        # if eval is called w/o train, handle model prep here
        if self.is_deepspeed_enabled and self.deepspeed is None:
            _, _ = deepspeed_init(self, num_training_steps=0, inference=True)
        args = self.args
        model = self._wrap_model(self.model, training=False, dataloader=None)
        print(self.is_in_train, args.device, model.dtype, self.args.dataloader_num_workers, self.eval_cfg.split_list, self.eval_cfg.split)
        if len(self.accelerator._models) == 0 and model is self.model:
            model = (
                self.accelerator.prepare(model)
                if self.is_deepspeed_enabled
                else self.accelerator.prepare_model(model, evaluation_mode=True)
            )

            if self.is_fsdp_enabled:
                self.model = model

            # for the rest of this function `model` is the outside model, whether it was wrapped or not
            if model is not self.model:
                self.model_wrapped = model

            # backward compatibility
            if self.is_deepspeed_enabled:
                self.deepspeed = self.model_wrapped

        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
        # while ``train`` is running, cast it to the right dtype first and then put on device
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)
        model.eval()
        curr_step = self.state.global_step
        eval_cfg = self.eval_cfg
        # Read once up front so the merge step below also works when the task
        # list is empty (it used to be assigned only inside the task loop).
        world_size = self.accelerator.num_processes

        curr_save_dir = os.path.join(eval_cfg.save_dir, f"checkpoint-{curr_step}")
        Path(curr_save_dir).mkdir(parents=True, exist_ok=True)
        # Only consumed by the disabled interleaving step further down.
        forget_rate = eval_cfg.split.split('_')[0]

        task_fields = zip(
            eval_cfg.data_path,
            eval_cfg.split_list,
            eval_cfg.question_key,
            eval_cfg.answer_key,
            eval_cfg.eval_task,
            eval_cfg.base_answer_key,
            eval_cfg.perturbed_answer_key,
        )
        with torch.no_grad():
            for folder, split, question_key, answer_key, eval_task, base_answer_key, perturbed_answer_key in task_fields:
                # For some reason, Hydra is not interpreting the split correctly
                if eval_task == 'eval_log_forget':
                    split = eval_cfg.split
                print(f'Working on eval task {eval_task} with split {split}')
                if world_size == 1:
                    save_filename = os.path.join(curr_save_dir, f"{eval_task}.json")
                else:
                    # One partial file per process; merged after the loop.
                    save_filename = os.path.join(curr_save_dir, f"{eval_task}_{self.accelerator.local_process_index}.json")
                if os.path.exists(save_filename) and not eval_cfg.overwrite:
                    print(f"Skipping {eval_task} because {save_filename} already exists")
                    continue

                eval_dataloader, base_eval_dataloader, perturb_dataloader = get_dataloader(eval_cfg, eval_task, self.tokenizer, folder, split, question_key, answer_key, base_answer_key, perturbed_answer_key)
                eval_dataloader = self.accelerator.prepare(eval_dataloader)
                base_eval_dataloader = self.accelerator.prepare(base_eval_dataloader)
                perturb_dataloader = self.accelerator.prepare(perturb_dataloader)
                normalize_gt = False
                # if 'eval_log' not in eval_task:
                #     normalize_gt = True

                eval_logs = get_all_evals_rouge(eval_cfg, model, self.tokenizer, eval_task, eval_dataloader, base_eval_dataloader, perturb_dataloader, normalize_gt=normalize_gt)

                # (A leftover debugger breakpoint used to sit here; it stalled
                # every non-interactive / multi-process run.)
                with open(save_filename, "w") as f:
                    # pretty write json to f
                    json.dump(eval_logs, f, indent=4)

            # wait for all processes to finish
            self.accelerator.wait_for_everyone()
            aggregated_eval_logs = {}
            # NOTE(review): with a single process nothing is added to
            # `aggregated_eval_logs`, so the aggregated file below is written as
            # `{}` -- confirm whether the per-task logs should be loaded here too.
            if world_size > 1 and self.accelerator.is_local_main_process:
                for eval_task in eval_cfg.eval_task:
                    # read the saved files as json and merge them using merge_dicts
                    shard_files = [
                        os.path.join(curr_save_dir, f"{eval_task}_{rank}.json")
                        for rank in range(world_size)
                    ]
                    with open(shard_files[0]) as f:
                        eval_logs = json.load(f)
                    for shard_file in shard_files[1:]:
                        with open(shard_file) as f:
                            eval_logs = merge_dicts(eval_logs, json.load(f))

                    aggregated_eval_logs[f'{eval_task}.json'] = eval_logs

                    new_save_filename = os.path.join(curr_save_dir, f"{eval_task}.json")
                    with open(new_save_filename, "w") as f:
                        # pretty write json to f
                        json.dump(eval_logs, f, indent=4)

                    # the merged file is on disk; drop the per-process partial files
                    for shard_file in shard_files:
                        os.remove(shard_file)

            if self.accelerator.is_local_main_process:
                # aggregated_eval_logs = interleave_eval_result_dict(aggregated_eval_logs, forget_rate, large_bsz=eval_cfg.batch_size, num_processes=world_size)
                aggregated_eval_log_filename = os.path.join(curr_save_dir, "eval_log_aggregated.json")

                with open(aggregated_eval_log_filename, 'w') as f:
                    json.dump(aggregated_eval_logs, f, indent=4)

                if eval_cfg.retain_result is not None:
                    model_utility = get_model_utility(aggregated_eval_logs)
                    with open(eval_cfg.retain_result, 'r') as f:
                        retain_result = json.load(f)
                    forget_quality = get_forget_quality(aggregated_eval_logs, retain_result)
                    aggregate_stat = {**model_utility, **forget_quality}

                    # save aggregate_stat as csv
                    with open(os.path.join(curr_save_dir, "aggregate_stat.csv"), 'w') as csvfile:
                        field_names = list(aggregate_stat.keys())
                        writer = csv.DictWriter(csvfile, fieldnames=field_names)
                        writer.writeheader()
                        writer.writerow(aggregate_stat)

def custom_data_collator_forget(samples):
    """Collate ``(forget_item, retain_item)`` pairs into two stacked batches.

    Each item is indexed as ``(input_ids, labels, attention_mask)``.  The result
    is a two-element list ``[forget_batch, retain_batch]``; each batch is a tuple
    of the three fields stacked along a new leading dimension.
    """
    forget_items = [pair[0] for pair in samples]
    retain_items = [pair[1] for pair in samples]

    def _stack_fields(items):
        # Stack input_ids, labels and attention_mask, in that order.
        return tuple(torch.stack([item[field] for item in items]) for field in range(3))

    return [_stack_fields(forget_items), _stack_fields(retain_items)]

def custom_data_collator_forget_switch(samples):
    """Collate ``(forget_item, retain_item)`` pairs whose items also carry a coefficient.

    Each item is indexed as ``(input_ids, labels, attention_mask, coefs)``.  The
    result is ``[forget_batch, retain_batch]``, each a 4-tuple of the fields
    stacked along a new leading dimension.
    """
    forget_items = [pair[0] for pair in samples]
    retain_items = [pair[1] for pair in samples]

    batches = []
    for items in (forget_items, retain_items):
        stacked = []
        for field in range(4):  # input_ids, labels, attention_mask, coefs
            stacked.append(torch.stack([item[field] for item in items]))
        batches.append(tuple(stacked))
    return batches



def compute_metrics(pred):
    """Return next-token accuracy and cross-entropy loss for a prediction object.

    ``pred.predictions`` (logits) and ``pred.label_ids`` are NumPy arrays.  The
    returned dict holds ``"eval accuracy"`` (a 0-dim tensor) and ``"eval loss"``
    (a Python float).
    """
    raw_predictions = pred.predictions
    logits = torch.from_numpy(raw_predictions)
    labels = torch.from_numpy(pred.label_ids)
    predicted_ids = torch.from_numpy(raw_predictions.argmax(-1))

    # Position t predicts token t+1: drop the last prediction and the first label.
    targets = labels[..., 1:].contiguous()
    # NOTE(review): every position counts towards the mean, including any whose
    # label is the ignore value that `get_loss` skips -- confirm this is intended.
    matches = (predicted_ids[..., :-1] == targets).float()
    acc = matches.mean()
    loss = get_loss(logits, labels)
    print("computed")
    return {"eval accuracy": acc, "eval loss": loss.item()}

def get_loss(output, labels):
    """Mean next-token cross-entropy of ``output`` logits against ``labels``.

    The last logit and the first label along the sequence axis are dropped so
    that position ``t`` is scored against token ``t+1``; labels equal to ``-100``
    are ignored.
    """
    vocab_size = output.size(-1)
    flat_logits = output[..., :-1, :].reshape(-1, vocab_size)
    flat_targets = labels[..., 1:].reshape(-1)
    return F.cross_entropy(flat_logits, flat_targets, ignore_index=-100)


